{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Numpy实现K折交叉验证的数据划分"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本实例使用Numpy的数组切片语法，实现了K折交叉验证的数据划分"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 背景：K折交叉验证\n",
    "\n",
    "***为什么需要这个？***  \n",
    "在机器学习中，因为如下原因，使用K折交叉验证能更好评估模型效果：\n",
    "1. 样本量不充足，划分了训练集和测试集后，训练数据更少；\n",
    "2. 训练集和测试集的不同划分，可能会导致不同的模型性能结果；\n",
    "\n",
    "\n",
    "***K折验证是什么***  \n",
    "K折验证（K-fold validtion）将数据划分为大小相同的K个分区。  \n",
    "对每个分区i，在剩余的K-1个分区上训练模型，然后在分区i上评估模型。  \n",
    "最终分数等于K个分数的平均值，使用平均值来消除训练集和测试集的划分影响；\n",
    "\n",
    "<img src=\"./other_files/numpy-kfold-validation.png\" style=\"margin-left:0px; width:60%;\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. 模拟构造样本集合"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0,  1,  2,  3],\n",
       "       [ 4,  5,  6,  7],\n",
       "       [ 8,  9, 10, 11],\n",
       "       [12, 13, 14, 15],\n",
       "       [16, 17, 18, 19],\n",
       "       [20, 21, 22, 23],\n",
       "       [24, 25, 26, 27],\n",
       "       [28, 29, 30, 31],\n",
       "       [32, 33, 34, 35]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = np.arange(36).reshape(9,4)\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "用样本的角度解释下data数组：\n",
    "* 这是一个二维矩阵，行代表每个样本，列代表每个特征\n",
    "* 这里有9个样本，每个样本有4个特征\n",
    "\n",
    "这是scikit-learn模型训练输入的标准格式"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 使用Numpy实现K次划分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 我们想进行4折交叉验证\n",
    "k = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 算出来每个fold的样本个数\n",
    "k_samples_count = data.shape[0]//k\n",
    "k_samples_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "#####第0折#####\n",
      "验证集：\n",
      " [[0 1 2 3]\n",
      " [4 5 6 7]]\n",
      "训练集：\n",
      " [[ 8  9 10 11]\n",
      " [12 13 14 15]\n",
      " [16 17 18 19]\n",
      " [20 21 22 23]\n",
      " [24 25 26 27]\n",
      " [28 29 30 31]\n",
      " [32 33 34 35]]\n",
      "\n",
      "#####第1折#####\n",
      "验证集：\n",
      " [[ 8  9 10 11]\n",
      " [12 13 14 15]]\n",
      "训练集：\n",
      " [[ 0  1  2  3]\n",
      " [ 4  5  6  7]\n",
      " [16 17 18 19]\n",
      " [20 21 22 23]\n",
      " [24 25 26 27]\n",
      " [28 29 30 31]\n",
      " [32 33 34 35]]\n",
      "\n",
      "#####第2折#####\n",
      "验证集：\n",
      " [[16 17 18 19]\n",
      " [20 21 22 23]]\n",
      "训练集：\n",
      " [[ 0  1  2  3]\n",
      " [ 4  5  6  7]\n",
      " [ 8  9 10 11]\n",
      " [12 13 14 15]\n",
      " [24 25 26 27]\n",
      " [28 29 30 31]\n",
      " [32 33 34 35]]\n",
      "\n",
      "#####第3折#####\n",
      "验证集：\n",
      " [[24 25 26 27]\n",
      " [28 29 30 31]]\n",
      "训练集：\n",
      " [[ 0  1  2  3]\n",
      " [ 4  5  6  7]\n",
      " [ 8  9 10 11]\n",
      " [12 13 14 15]\n",
      " [16 17 18 19]\n",
      " [20 21 22 23]\n",
      " [32 33 34 35]]\n"
     ]
    }
   ],
   "source": [
    "for fold in range(k):\n",
    "    validation_begin = k_samples_count*fold\n",
    "    validation_end = k_samples_count*(fold+1)\n",
    "    \n",
    "    validation_data = data[validation_begin:validation_end]\n",
    "    \n",
    "    # np.vstack，沿着垂直的方向堆叠数组\n",
    "    train_data = np.vstack([\n",
    "        data[:validation_begin], \n",
    "        data[validation_end:]\n",
    "    ])\n",
    "    \n",
    "    print()\n",
    "    print(f\"#####第{fold}折#####\")\n",
    "    print(\"验证集：\\n\", validation_data)\n",
    "    print(\"训练集：\\n\", train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果使用scikit-learn，已经有封装好的实现：  \n",
    "from sklearn.model_selection import cross_val_score"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
